{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v1CUZ0dkOo_F"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "qmkj-80IHxnd"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_xnMOsbqHz61"
      },
      "source": [
        "# pix2pix: Image-to-image translation with a conditional GAN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ds4o1h4WHz9U"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/generative/pix2pix\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/generative/pix2pix.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/pix2pix.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/generative/pix2pix.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ITZuApL56Mny"
      },
      "source": [
        "This tutorial demonstrates how to build and train a conditional generative adversarial network (cGAN) called pix2pix that learns a mapping from input images to output images, as described in [Image-to-image translation with conditional adversarial networks](https://arxiv.org/abs/1611.07004) by Isola et al. (2017). pix2pix is not application specific—it can be applied to a wide range of tasks, including synthesizing photos from label maps, generating colorized photos from black and white images, turning Google Maps photos into aerial images, and even transforming sketches into photos.\n",
        "\n",
        "In this example, your network will generate images of building facades using the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/) provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep it short, you will use a [preprocessed copy]((https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/)) of this dataset created by the pix2pix authors.\n",
        "\n",
        "In the pix2pix cGAN, you condition on input images and generate corresponding output images. cGANs were first proposed in [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784) (Mirza and Osindero, 2014)\n",
        "\n",
        "The architecture of your network will contain:\n",
        "\n",
        "- A generator with a [U-Net]([U-Net](https://arxiv.org/abs/1505.04597))-based architecture.\n",
        "- A discriminator represented by a convolutional PatchGAN classifier (proposed in the [pix2pix paper](https://arxiv.org/abs/1611.07004)).\n",
        "\n",
        "Note that each epoch can take around 15 seconds on a single V100 GPU.\n",
        "\n",
        "Below are some examples of the output generated by the pix2pix cGAN after training for 200 epochs.\n",
        "\n",
        "![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)\n",
        "![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e1_Y75QXJS6h"
      },
      "source": [
        "## Import TensorFlow and other libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YfIk2es3hJEd"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import os\n",
        "import time\n",
        "import datetime\n",
        "\n",
        "from matplotlib import pyplot as plt\n",
        "from IPython import display"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iYn4MdZnKCey"
      },
      "source": [
        "## Load the dataset\n",
        "\n",
        "Let's load the CMP Facade Database data. (You can try other preprocessed pix2pix-related datasets [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets).)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kn-k8kTXuAlv"
      },
      "outputs": [],
      "source": [
        "_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'\n",
        "\n",
        "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n",
        "                                      origin=_URL,\n",
        "                                      extract=True)\n",
        "\n",
        "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1fUzsnerj1P3"
      },
      "source": [
        "Each original image is of size `256 x 512` containing two `256 x 256` images:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XGY1kiptguTQ"
      },
      "outputs": [],
      "source": [
        "sample_image = tf.io.read_file(PATH + 'train/1.jpg')\n",
        "sample_image = tf.io.decode_jpeg(sample_image)\n",
        "print(sample_image.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vJ2sO8Izg7QV"
      },
      "outputs": [],
      "source": [
        "plt.figure()\n",
        "plt.imshow(sample_image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2A5SU-qxPAqd"
      },
      "source": [
        "You need to separate real building facade images from the architecture label images—all of which will be of size `256 x 256`.\n",
        "\n",
        "Define a function that loads image files and outputs two image tensors:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aO9ZAGH5K3SY"
      },
      "outputs": [],
      "source": [
        "def load(image_file):\n",
        "  # Read and decode an image file to a uint8 tensor\n",
        "  image = tf.io.read_file(image_file)\n",
        "  image = tf.image.decode_jpeg(image)\n",
        "\n",
        "  # Split each image tensor into two tensors:\n",
        "  # - one with a real building facade image\n",
        "  # - one with an architecture label image \n",
        "  w = tf.shape(image)[1]\n",
        "  w = w // 2\n",
        "  input_image = image[:, w:, :]\n",
        "  real_image = image[:, :w, :]\n",
        "\n",
        "  # Convert both images to float32 tensors\n",
        "  input_image = tf.cast(input_image, tf.float32)\n",
        "  real_image = tf.cast(real_image, tf.float32)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r5ByHTlfE06P"
      },
      "source": [
        "Plot a sample of the input (architecture label image) and real (building facade photo) images:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4OLHMpsQ5aOv"
      },
      "outputs": [],
      "source": [
        "inp, re = load(PATH + 'train/100.jpg')\n",
        "# Casting to int for matplotlib to display the images\n",
        "plt.figure()\n",
        "plt.imshow(inp / 255.0)\n",
        "plt.figure()\n",
        "plt.imshow(re / 255.0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PVuZQTfI_c-s"
      },
      "source": [
        "As described in the [pix2pix paper](https://arxiv.org/abs/1611.07004), you need to apply random jittering and mirroring to preprocess the training set.\n",
        "\n",
        "Define several functions that:\n",
        "\n",
        "1. Resize each `256 x 256` image to a larger height and width—`286 x 286`.\n",
        "2. Randomly crop it back to `256 x 256`.\n",
        "3. Randomly flip the image horizontally i.e. left to right (random mirroring).\n",
        "4. Normalize the images to the `[-1, 1]` range."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2CbTEt448b4R"
      },
      "outputs": [],
      "source": [
        "# The facade training set consist of 400 images\n",
        "BUFFER_SIZE = 400\n",
        "# The batch size of 1 produced better results for the U-Net in the original pix2pix experiment\n",
        "BATCH_SIZE = 1\n",
        "# Each image is 256x256 in size\n",
        "IMG_WIDTH = 256\n",
        "IMG_HEIGHT = 256"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rwwYQpu9FzDu"
      },
      "outputs": [],
      "source": [
        "def resize(input_image, real_image, height, width):\n",
        "  input_image = tf.image.resize(input_image, [height, width],\n",
        "                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "  real_image = tf.image.resize(real_image, [height, width],\n",
        "                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Yn3IwqhiIszt"
      },
      "outputs": [],
      "source": [
        "def random_crop(input_image, real_image):\n",
        "  stacked_image = tf.stack([input_image, real_image], axis=0)\n",
        "  cropped_image = tf.image.random_crop(\n",
        "      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n",
        "\n",
        "  return cropped_image[0], cropped_image[1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "muhR2cgbLKWW"
      },
      "outputs": [],
      "source": [
        "# Normalizing the images to [-1, 1]\n",
        "def normalize(input_image, real_image):\n",
        "  input_image = (input_image / 127.5) - 1\n",
        "  real_image = (real_image / 127.5) - 1\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
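    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check of the scaling in `normalize`, here is a minimal pure-Python sketch of the same arithmetic on single pixel values. The helper names `to_unit_range` and `to_display_range` are hypothetical and not part of this notebook; the second one is an assumed inverse that can be handy when plotting normalized images.\n",
        "\n",
        "```python\n",
        "def to_unit_range(pixel):\n",
        "  # Same arithmetic as normalize(): maps [0, 255] to [-1, 1]\n",
        "  return pixel / 127.5 - 1\n",
        "\n",
        "\n",
        "def to_display_range(value):\n",
        "  # Hypothetical inverse for plotting: maps [-1, 1] to [0, 1]\n",
        "  return value * 0.5 + 0.5\n",
        "\n",
        "\n",
        "for pixel in (0, 127.5, 255):\n",
        "  scaled = to_unit_range(pixel)\n",
        "  print(pixel, '->', scaled, '->', to_display_range(scaled))\n",
        "```\n",
        "\n",
        "Pixel values in `[0, 255]` land in `[-1, 1]`, which is also the output range of the `tanh` activation used in the last layer of the generator defined later in this tutorial."
      ]
    },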
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fVQOjcPVLrUc"
      },
      "outputs": [],
      "source": [
        "@tf.function()\n",
        "def random_jitter(input_image, real_image):\n",
        "  # Resizing to 286x286\n",
        "  input_image, real_image = resize(input_image, real_image, 286, 286)\n",
        "\n",
        "  # Random cropping back to 256x256\n",
        "  input_image, real_image = random_crop(input_image, real_image)\n",
        "\n",
        "  if tf.random.uniform(()) > 0.5:\n",
        "    # Random mirroring\n",
        "    input_image = tf.image.flip_left_right(input_image)\n",
        "    real_image = tf.image.flip_left_right(real_image)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wfAQbzy799UV"
      },
      "source": [
        "You can inspect some of the preprocessed output:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n0OGdi6D92kM"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(6, 6))\n",
        "for i in range(4):\n",
        "  rj_inp, rj_re = random_jitter(inp, re)\n",
        "  plt.subplot(2, 2, i + 1)\n",
        "  plt.imshow(rj_inp / 255.0)\n",
        "  plt.axis('off')\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3E9LGq3WBmsh"
      },
      "source": [
        "Having checked that the loading and preprocessing works, let's define a couple of helper functions that load and preprocess the training and test sets:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tyaP4hLJ8b4W"
      },
      "outputs": [],
      "source": [
        "def load_image_train(image_file):\n",
        "  input_image, real_image = load(image_file)\n",
        "  input_image, real_image = random_jitter(input_image, real_image)\n",
        "  input_image, real_image = normalize(input_image, real_image)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VB3Z6D_zKSru"
      },
      "outputs": [],
      "source": [
        "def load_image_test(image_file):\n",
        "  input_image, real_image = load(image_file)\n",
        "  input_image, real_image = resize(input_image, real_image,\n",
        "                                   IMG_HEIGHT, IMG_WIDTH)\n",
        "  input_image, real_image = normalize(input_image, real_image)\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PIGN6ouoQxt3"
      },
      "source": [
        "## Build an input pipeline with `tf.data`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SQHmYSmk8b4b"
      },
      "outputs": [],
      "source": [
        "train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')\n",
        "train_dataset = train_dataset.map(load_image_train,\n",
        "                                  num_parallel_calls=tf.data.AUTOTUNE)\n",
        "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n",
        "train_dataset = train_dataset.batch(BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MS9J0yA58b4g"
      },
      "outputs": [],
      "source": [
        "test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')\n",
        "test_dataset = test_dataset.map(load_image_test)\n",
        "test_dataset = test_dataset.batch(BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "THY-sZMiQ4UV"
      },
      "source": [
        "## Build the generator\n",
        "\n",
        "The generator of your pix2pix cGAN is a _modified_ [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and decoder (upsampler). (You can find out more about it in the [Image segmentation](https://www.tensorflow.org/tutorials/images/segmentation) tutorial and on the [U-Net project website](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/).)\n",
        "\n",
        "- Each block in the encoder is: Convolution -> Batch normalization -> Leaky ReLU\n",
        "- Each block in the decoder is: Transposed convolution -> Batch normalization -> Dropout (applied to the first 3 blocks) -> ReLU\n",
        "- There are skip connections between the encoder and decoder (as in the U-Net)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4MQPuBCgtldI"
      },
      "source": [
        "Define the downsampler (encoder):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tqqvWxlw8b4l"
      },
      "outputs": [],
      "source": [
        "OUTPUT_CHANNELS = 3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3R09ATE_SH9P"
      },
      "outputs": [],
      "source": [
        "def downsample(filters, size, apply_batchnorm=True):\n",
        "  initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "  result = tf.keras.Sequential()\n",
        "  result.add(\n",
        "      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',\n",
        "                             kernel_initializer=initializer, use_bias=False))\n",
        "\n",
        "  if apply_batchnorm:\n",
        "    result.add(tf.keras.layers.BatchNormalization())\n",
        "\n",
        "  result.add(tf.keras.layers.LeakyReLU())\n",
        "\n",
        "  return result"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a6_uCZCppTh7"
      },
      "outputs": [],
      "source": [
        "down_model = downsample(3, 4)\n",
        "down_result = down_model(tf.expand_dims(inp, 0))\n",
        "print (down_result.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aFI_Pa52tjLl"
      },
      "source": [
        "Define the upsampler (decoder):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nhgDsHClSQzP"
      },
      "outputs": [],
      "source": [
        "def upsample(filters, size, apply_dropout=False):\n",
        "  initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "  result = tf.keras.Sequential()\n",
        "  result.add(\n",
        "    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,\n",
        "                                    padding='same',\n",
        "                                    kernel_initializer=initializer,\n",
        "                                    use_bias=False))\n",
        "\n",
        "  result.add(tf.keras.layers.BatchNormalization())\n",
        "\n",
        "  if apply_dropout:\n",
        "      result.add(tf.keras.layers.Dropout(0.5))\n",
        "\n",
        "  result.add(tf.keras.layers.ReLU())\n",
        "\n",
        "  return result"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mz-ahSdsq0Oc"
      },
      "outputs": [],
      "source": [
        "up_model = upsample(3, 4)\n",
        "up_result = up_model(down_result)\n",
        "print (up_result.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ueEJyRVrtZ-p"
      },
      "source": [
        "Define the generator with the downsampler and the upsampler:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFPI4Nu-8b4q"
      },
      "outputs": [],
      "source": [
        "def Generator():\n",
        "  inputs = tf.keras.layers.Input(shape=[256, 256, 3])\n",
        "\n",
        "  down_stack = [\n",
        "    downsample(64, 4, apply_batchnorm=False),  # (batch_size, 128, 128, 64)\n",
        "    downsample(128, 4),  # (batch_size, 64, 64, 128)\n",
        "    downsample(256, 4),  # (batch_size, 32, 32, 256)\n",
        "    downsample(512, 4),  # (batch_size, 16, 16, 512)\n",
        "    downsample(512, 4),  # (batch_size, 8, 8, 512)\n",
        "    downsample(512, 4),  # (batch_size, 4, 4, 512)\n",
        "    downsample(512, 4),  # (batch_size, 2, 2, 512)\n",
        "    downsample(512, 4),  # (batch_size, 1, 1, 512)\n",
        "  ]\n",
        "\n",
        "  up_stack = [\n",
        "    upsample(512, 4, apply_dropout=True),  # (batch_size, 2, 2, 1024)\n",
        "    upsample(512, 4, apply_dropout=True),  # (batch_size, 4, 4, 1024)\n",
        "    upsample(512, 4, apply_dropout=True),  # (batch_size, 8, 8, 1024)\n",
        "    upsample(512, 4),  # (batch_size, 16, 16, 1024)\n",
        "    upsample(256, 4),  # (batch_size, 32, 32, 512)\n",
        "    upsample(128, 4),  # (batch_size, 64, 64, 256)\n",
        "    upsample(64, 4),  # (batch_size, 128, 128, 128)\n",
        "  ]\n",
        "\n",
        "  initializer = tf.random_normal_initializer(0., 0.02)\n",
        "  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,\n",
        "                                         strides=2,\n",
        "                                         padding='same',\n",
        "                                         kernel_initializer=initializer,\n",
        "                                         activation='tanh')  # (batch_size, 256, 256, 3)\n",
        "\n",
        "  x = inputs\n",
        "\n",
        "  # Downsampling through the model\n",
        "  skips = []\n",
        "  for down in down_stack:\n",
        "    x = down(x)\n",
        "    skips.append(x)\n",
        "\n",
        "  skips = reversed(skips[:-1])\n",
        "\n",
        "  # Upsampling and establishing the skip connections\n",
        "  for up, skip in zip(up_stack, skips):\n",
        "    x = up(x)\n",
        "    x = tf.keras.layers.Concatenate()([x, skip])\n",
        "\n",
        "  x = last(x)\n",
        "\n",
        "  return tf.keras.Model(inputs=inputs, outputs=x)"
      ]
    },
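    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The shape comments in `Generator` can be checked without building the model. The sketch below is plain Python (no TensorFlow) and assumes only what the code above states: every block uses stride 2 with `'same'` padding, and each upsampled tensor is concatenated with the matching encoder output. The variable names are illustrative.\n",
        "\n",
        "```python\n",
        "# Filter counts copied from down_stack and up_stack in Generator()\n",
        "down_filters = [64, 128, 256, 512, 512, 512, 512, 512]\n",
        "up_filters = [512, 512, 512, 512, 256, 128, 64]\n",
        "\n",
        "size = 256\n",
        "encoder_shapes = []\n",
        "for filters in down_filters:\n",
        "  size //= 2  # A stride-2 convolution with 'same' padding halves height and width\n",
        "  encoder_shapes.append((size, filters))\n",
        "\n",
        "# Skip connections reuse every encoder output except the bottleneck, deepest first\n",
        "skips = list(reversed(encoder_shapes[:-1]))\n",
        "\n",
        "decoder_shapes = []\n",
        "for filters, (skip_size, skip_filters) in zip(up_filters, skips):\n",
        "  size *= 2  # A stride-2 transposed convolution doubles height and width\n",
        "  assert size == skip_size  # Concatenation needs matching spatial sizes\n",
        "  decoder_shapes.append((size, filters + skip_filters))\n",
        "\n",
        "final_size = size * 2  # The last transposed convolution restores the input size\n",
        "print(encoder_shapes[-1], decoder_shapes, final_size)\n",
        "```\n",
        "\n",
        "The channel counts after concatenation (for example `512 + 512 = 1024` in the first decoder block) are the ones shown in the `up_stack` comments."
      ]
    },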
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z4PKwrcQFYvF"
      },
      "source": [
        "Visualize the generator model architecture:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dIbRPFzjmV85"
      },
      "outputs": [],
      "source": [
        "generator = Generator()\n",
        "tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z8kbgTK8FcPo"
      },
      "source": [
        "Test the generator:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U1N1_obwtdQH"
      },
      "outputs": [],
      "source": [
        "gen_output = generator(inp[tf.newaxis, ...], training=False)\n",
        "plt.imshow(gen_output[0, ...])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dpDPEQXIAiQO"
      },
      "source": [
        "### Define the generator loss\n",
        "\n",
        "GANs learn a loss that adapts to the data, while cGANs learn a structured loss that penalizes a possible structure that differs from the network output and the target image, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).\n",
        "\n",
        "- The generator loss is a sigmoid cross-entropy loss of the generated images and an **array of ones**.\n",
        "- The pix2pix paper also mentions the L1 loss, which is a MAE (mean absolute error) between the generated image and the target image.\n",
        "- This allows the generated image to become structurally similar to the target image.\n",
        "- The formula to calculate the total generator loss is `gan_loss + LAMBDA * l1_loss`, where `LAMBDA = 100`. This value was decided by the authors of the paper."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cyhxTuvJyIHV"
      },
      "outputs": [],
      "source": [
        "LAMBDA = 100"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q1Xbz5OaLj5C"
      },
      "outputs": [],
      "source": [
        "loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "90BIcCKcDMxz"
      },
      "outputs": [],
      "source": [
        "def generator_loss(disc_generated_output, gen_output, target):\n",
        "  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)\n",
        "\n",
        "  # Mean absolute error\n",
        "  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
        "\n",
        "  total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n",
        "\n",
        "  return total_gen_loss, gan_loss, l1_loss"
      ]
    },
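    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for how the two terms combine, here is a small worked example in plain Python. The pixel values and the adversarial loss value are made up for illustration, and `l1_loss_sketch` is a hypothetical stand-in for the `tf.reduce_mean(tf.abs(...))` call above.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "LAMBDA = 100\n",
        "\n",
        "\n",
        "def l1_loss_sketch(target, generated):\n",
        "  # Mean absolute error over all pixel values\n",
        "  return sum(abs(t - g) for t, g in zip(target, generated)) / len(target)\n",
        "\n",
        "\n",
        "# Hypothetical flattened pixel values in the [-1, 1] range\n",
        "target = [1.0, -1.0, 0.5, 0.0]\n",
        "generated = [0.5, -0.5, 0.5, 0.5]\n",
        "\n",
        "# Hypothetical adversarial term: the sigmoid cross-entropy against a label of 1\n",
        "# when the discriminator outputs a logit of 0 (undecided) is ln(2)\n",
        "gan_loss = math.log(2)\n",
        "\n",
        "l1_loss = l1_loss_sketch(target, generated)\n",
        "total_gen_loss = gan_loss + LAMBDA * l1_loss\n",
        "print(gan_loss, l1_loss, total_gen_loss)\n",
        "```\n",
        "\n",
        "With `LAMBDA = 100`, even a modest per-pixel error contributes far more to the total than the adversarial term in this example, so the L1 term pulls the output strongly toward the target image."
      ]
    },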
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fSZbDgESHIV6"
      },
      "source": [
        "The training procedure for the generator is as follows:"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TlB-XMY5Awj9"
      },
      "source": [
        "![Generator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gen.png?raw=1)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZTKZfoaoEF22"
      },
      "source": [
        "## Build the discriminator\n",
        "\n",
        "The discriminator in the pix2pix cGAN is a convolutional PatchGAN classifier—it tries to classify if each image _patch_ is real or not real, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).\n",
        "\n",
        "- Each block in the discriminator is: Convolution -> Batch normalization -> Leaky ReLU.\n",
        "- The shape of the output after the last layer is `(batch_size, 30, 30, 1)`.\n",
        "- Each `30 x 30` image patch of the output classifies a `70 x 70` portion of the input image.\n",
        "- The discriminator receives 2 inputs: \n",
        "    - The input image and the target image, which it should classify as real.\n",
        "    - The input image and the generated image (the output of the generator), which it should classify as fake.\n",
        "    - Use `tf.concat([inp, tar], axis=-1)` to concatenate these 2 inputs together."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XIuTeGL5v45m"
      },
      "source": [
        "Let's define the discriminator:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ll6aNeQx8b4v"
      },
      "outputs": [],
      "source": [
        "def Discriminator():\n",
        "  initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')\n",
        "  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')\n",
        "\n",
        "  x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 256, 256, channels*2)\n",
        "\n",
        "  down1 = downsample(64, 4, False)(x)  # (batch_size, 128, 128, 64)\n",
        "  down2 = downsample(128, 4)(down1)  # (batch_size, 64, 64, 128)\n",
        "  down3 = downsample(256, 4)(down2)  # (batch_size, 32, 32, 256)\n",
        "\n",
        "  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch_size, 34, 34, 256)\n",
        "  conv = tf.keras.layers.Conv2D(512, 4, strides=1,\n",
        "                                kernel_initializer=initializer,\n",
        "                                use_bias=False)(zero_pad1)  # (batch_size, 31, 31, 512)\n",
        "\n",
        "  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)\n",
        "\n",
        "  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)\n",
        "\n",
        "  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch_size, 33, 33, 512)\n",
        "\n",
        "  last = tf.keras.layers.Conv2D(1, 4, strides=1,\n",
        "                                kernel_initializer=initializer)(zero_pad2)  # (batch_size, 30, 30, 1)\n",
        "\n",
        "  return tf.keras.Model(inputs=[inp, tar], outputs=last)"
      ]
    },
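    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `30 x 30` output size and the `70 x 70` patch size both follow from the layer settings in `Discriminator`. The plain-Python sketch below works them out by hand; the list `conv_layers` is an illustrative summary of the kernel sizes and strides used above, not something the model exposes.\n",
        "\n",
        "```python\n",
        "# (kernel_size, stride) of each convolution in Discriminator(), first to last\n",
        "conv_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]\n",
        "\n",
        "# Receptive field: start from one output value and walk back to the input\n",
        "receptive_field = 1\n",
        "for kernel, stride in reversed(conv_layers):\n",
        "  receptive_field = (receptive_field - 1) * stride + kernel\n",
        "\n",
        "# Output size: three stride-2 blocks with 'same' padding,\n",
        "# then two zero-padded 4x4 convolutions with 'valid' padding\n",
        "size = 256\n",
        "for _ in range(3):\n",
        "  size //= 2  # 256 -> 128 -> 64 -> 32\n",
        "for _ in range(2):\n",
        "  size = (size + 2) - 4 + 1  # ZeroPadding2D adds 1 pixel per side: 32 -> 31 -> 30\n",
        "\n",
        "print(receptive_field, size)\n",
        "```\n",
        "\n",
        "Each of the `30 x 30` output values therefore depends on a `70 x 70` region of the (concatenated) input."
      ]
    },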
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HdV9yAbBHNkg"
      },
      "source": [
        "Visualize the discriminator model architecture:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YHoUui4om-Ev"
      },
      "outputs": [],
      "source": [
        "discriminator = Discriminator()\n",
        "tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ps7nIHigHYc7"
      },
      "source": [
        "Test the discriminator:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gDkA05NE6QMs"
      },
      "outputs": [],
      "source": [
        "disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)\n",
        "plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')\n",
        "plt.colorbar()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AOqg1dhUAWoD"
      },
      "source": [
        "### Define the discriminator loss\n",
        "\n",
        "- The `discriminator_loss` function takes 2 inputs: **real images** and **generated images**.\n",
        "- `real_loss` is a sigmoid cross-entropy loss of the **real images** and an **array of ones(since these are the real images)**.\n",
        "- `generated_loss` is a sigmoid cross-entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**.\n",
        "- The `total_loss` is the sum of `real_loss` and `generated_loss`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wkMNfBWlT-PV"
      },
      "outputs": [],
      "source": [
        "def discriminator_loss(disc_real_output, disc_generated_output):\n",
        "  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)\n",
        "\n",
        "  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)\n",
        "\n",
        "  total_disc_loss = real_loss + generated_loss\n",
        "\n",
        "  return total_disc_loss"
      ]
    },
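    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the sigmoid cross-entropy terms concrete, here is a minimal pure-Python sketch that computes the loss directly from logits, which is the kind of input `loss_object` expects because it was created with `from_logits=True`. The function names and the logit values are hypothetical and only serve as an illustration.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def sigmoid_cross_entropy(label, logit):\n",
        "  # Numerically stable form of\n",
        "  # -label * log(sigmoid(logit)) - (1 - label) * log(1 - sigmoid(logit))\n",
        "  return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))\n",
        "\n",
        "\n",
        "def mean_loss(label, logits):\n",
        "  # Average the per-patch losses\n",
        "  return sum(sigmoid_cross_entropy(label, x) for x in logits) / len(logits)\n",
        "\n",
        "\n",
        "# Hypothetical patch logits produced by the discriminator\n",
        "real_logits = [3.0, 2.0, 4.0]\n",
        "generated_logits = [-3.0, -2.0, 0.0]\n",
        "\n",
        "real_loss = mean_loss(1.0, real_logits)            # Compared against ones\n",
        "generated_loss = mean_loss(0.0, generated_logits)  # Compared against zeros\n",
        "total_disc_loss = real_loss + generated_loss\n",
        "print(real_loss, generated_loss, total_disc_loss)\n",
        "```\n",
        "\n",
        "Confident, correct logits give a loss close to zero, while an undecided logit of `0` costs `ln(2)` regardless of the label."
      ]
    },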
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ede4p2YELFa"
      },
      "source": [
        "The training procedure for the discriminator is shown below.\n",
        "\n",
        "To learn more about the architecture and the hyperparameters you can refer to the [pix2pix paper](https://arxiv.org/abs/1611.07004)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IS9sHa-1BoAF"
      },
      "source": [
        "![Discriminator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/dis.png?raw=1)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0FMYgY_mPfTi"
      },
      "source": [
        "## Define the optimizers and a checkpoint-saver\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lbHFNexF0x6O"
      },
      "outputs": [],
      "source": [
        "generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
        "discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WJnftd5sQsv6"
      },
      "outputs": [],
      "source": [
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
        "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
        "                                 discriminator_optimizer=discriminator_optimizer,\n",
        "                                 generator=generator,\n",
        "                                 discriminator=discriminator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rw1fkAczTQYh"
      },
      "source": [
        "## Generate images\n",
        "\n",
        "Write a function to plot some images during training.\n",
        "\n",
        "- Pass images from the test set to the generator.\n",
        "- The generator will then translate the input image into the output.\n",
        "- The last step is to plot the predictions and _voila_!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rb0QQFHF-JfS"
      },
      "source": [
        "Note: The `training=True` is intentional here since\n",
        "you want the batch statistics, while running the model on the test dataset. If you use `training=False`, you get the accumulated statistics learned from the training dataset (which you don't want)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RmdVsmvhPxyy"
      },
      "outputs": [],
      "source": [
        "def generate_images(model, test_input, tar):\n",
        "  prediction = model(test_input, training=True)\n",
        "  plt.figure(figsize=(15, 15))\n",
        "\n",
        "  display_list = [test_input[0], tar[0], prediction[0]]\n",
        "  title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
        "\n",
        "  for i in range(3):\n",
        "    plt.subplot(1, 3, i+1)\n",
        "    plt.title(title[i])\n",
        "    # Getting the pixel values in the [0, 1] range to plot.\n",
        "    plt.imshow(display_list[i] * 0.5 + 0.5)\n",
        "    plt.axis('off')\n",
        "  plt.show()"
      ]
    },
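    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of `training=True` described in the note above can be seen in isolation. The following is a minimal sketch, not part of the pix2pix model: it assumes a standalone `tf.keras.layers.BatchNormalization` layer and a made-up three-value batch, and simply compares the two modes.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical stand-alone layer and toy batch, for illustration only.\n",
        "bn = tf.keras.layers.BatchNormalization()\n",
        "x = tf.constant([[10.0], [20.0], [30.0]])\n",
        "\n",
        "# training=True normalizes with the mean and variance of this batch.\n",
        "print(bn(x, training=True).numpy().ravel())\n",
        "\n",
        "# training=False normalizes with the layer's accumulated moving averages.\n",
        "print(bn(x, training=False).numpy().ravel())\n",
        "```\n",
        "\n",
        "The first call should produce values centered around zero, while the second depends on whatever statistics the layer has accumulated so far."
      ]
    },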
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gipsSEoZIG1a"
      },
      "source": [
        "Test the function:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Fc4NzT-DgEx"
      },
      "outputs": [],
      "source": [
        "for example_input, example_target in test_dataset.take(1):\n",
        "  generate_images(generator, example_input, example_target)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NLKOG55MErD0"
      },
      "source": [
        "## Training\n",
        "\n",
        "- For each example input generates an output.\n",
        "- The discriminator receives the `input_image` and the generated image as the first input. The second input is the `input_image` and the `target_image`.\n",
        "- Next, calculate the generator and the discriminator loss.\n",
        "- Then, calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n",
        "- Finally, log the losses to TensorBoard."
      ]
    },
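    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The gradient step listed above can be sketched in isolation. This is a toy illustration, not the pix2pix `train_step`: the variable `w`, the loss `w * w`, and the SGD optimizer are assumptions chosen so that the arithmetic is easy to check by hand.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical single trainable variable and optimizer.\n",
        "w = tf.Variable(3.0)\n",
        "optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)\n",
        "\n",
        "# Record the forward pass so the loss can be differentiated.\n",
        "with tf.GradientTape() as tape:\n",
        "    loss = w * w  # d(loss)/dw = 2 * w = 6.0\n",
        "\n",
        "# Compute the gradients, then let the optimizer update the variable.\n",
        "grads = tape.gradient(loss, [w])\n",
        "optimizer.apply_gradients(zip(grads, [w]))\n",
        "\n",
        "print(w.numpy())  # roughly 3.0 - 0.1 * 6.0 = 2.4\n",
        "```\n",
        "\n",
        "The training step in this tutorial follows the same pattern, with one tape and one optimizer per model."
      ]
    },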
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NS2GWywBbAWo"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 150"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xNNMDBNH12q-"
      },
      "outputs": [],
      "source": [
        "log_dir=\"logs/\"\n",
        "\n",
        "summary_writer = tf.summary.create_file_writer(\n",
        "  log_dir + \"fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KBKUV2sKXDbY"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(input_image, target, epoch):\n",
        "  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
        "    gen_output = generator(input_image, training=True)\n",
        "\n",
        "    disc_real_output = discriminator([input_image, target], training=True)\n",
        "    disc_generated_output = discriminator([input_image, gen_output], training=True)\n",
        "\n",
        "    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)\n",
        "    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n",
        "\n",
        "  generator_gradients = gen_tape.gradient(gen_total_loss,\n",
        "                                          generator.trainable_variables)\n",
        "  discriminator_gradients = disc_tape.gradient(disc_loss,\n",
        "                                               discriminator.trainable_variables)\n",
        "\n",
        "  generator_optimizer.apply_gradients(zip(generator_gradients,\n",
        "                                          generator.trainable_variables))\n",
        "  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,\n",
        "                                              discriminator.trainable_variables))\n",
        "\n",
        "  with summary_writer.as_default():\n",
        "    tf.summary.scalar('gen_total_loss', gen_total_loss, step=epoch)\n",
        "    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=epoch)\n",
        "    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=epoch)\n",
        "    tf.summary.scalar('disc_loss', disc_loss, step=epoch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hx7s-vBHFKdh"
      },
      "source": [
        "The actual training loop:\n",
        "\n",
        "- Iterates over the number of epochs.\n",
        "- At each epoch: clears the display and runs `generate_images` to show the progress.\n",
        "- At each epoch: iterates over the training dataset, printing a dot (`.`) to indicate progress for each example.\n",
        "- Every 20 epochs: saves a checkpoint."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2M7LmLtGEMQJ"
      },
      "outputs": [],
      "source": [
        "def fit(train_ds, epochs, test_ds):\n",
        "  for epoch in range(epochs):\n",
        "    start = time.time()\n",
        "\n",
        "    display.clear_output(wait=True)\n",
        "\n",
        "    for example_input, example_target in test_ds.take(1):\n",
        "      generate_images(generator, example_input, example_target)\n",
        "    print(\"Epoch: \", epoch)\n",
        "\n",
        "    # Training step\n",
        "    for n, (input_image, target) in train_ds.enumerate():\n",
        "      print('.', end='')\n",
        "      if (n+1) % 100 == 0:\n",
        "        print()\n",
        "      train_step(input_image, target, epoch)\n",
        "    print()\n",
        "\n",
        "    # Saving (checkpointing) the model every 20 epochs\n",
        "    if (epoch + 1) % 20 == 0:\n",
        "      checkpoint.save(file_prefix=checkpoint_prefix)\n",
        "\n",
        "    print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
        "                                                        time.time()-start))\n",
        "  checkpoint.save(file_prefix=checkpoint_prefix)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wozqyTh2wmCu"
      },
      "source": [
        "This training loop saves logs that you can view in TensorBoard to monitor the training progress.\n",
        "\n",
        "If you work on a local machine, you would launch a separate TensorBoard process. When working in a notebook, launch the viewer before starting the training to monitor with TensorBoard.\n",
        "\n",
        "To launch the viewer paste the following into a code-cell:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ot22ujrlLhOd"
      },
      "outputs": [],
      "source": [
        "#docs_infra: no_execute\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir {log_dir}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pe0-8Bzg22ox"
      },
      "source": [
        "Finally, run the training loop:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a1zZmKmvOH85"
      },
      "outputs": [],
      "source": [
        "fit(train_dataset, EPOCHS, test_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oeq9sByu86-B"
      },
      "source": [
        "If you want to share the TensorBoard results _publicly_, you can upload the logs to [TensorBoard.dev](https://tensorboard.dev/) by copying the following into a code-cell.\n",
        "\n",
        "Note: This requires a Google account.\n",
        "\n",
        "```\n",
        "!tensorboard dev upload --logdir {log_dir}\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l-kT7WHRKz-E"
      },
      "source": [
        "Caution: This command does not terminate. It's designed to continuously upload the results of long-running experiments. Once your data is uploaded you need to stop it using the \"interrupt execution\" option in your notebook tool."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-lGhS_LfwQoL"
      },
      "source": [
        "You can view the [results of a previous run](https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw) of this notebook on [TensorBoard.dev](https://tensorboard.dev/).\n",
        "\n",
        "TensorBoard.dev is a managed experience for hosting, tracking, and sharing ML experiments with everyone.\n",
        "\n",
        "It can also included inline using an `<iframe>`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8IS4c93guQ8E"
      },
      "outputs": [],
      "source": [
        "display.IFrame(\n",
        "    src=\"https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw\",\n",
        "    width=\"100%\",\n",
        "    height=\"1000px\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DMTm4peo3cem"
      },
      "source": [
        "Interpreting the logs is more subtle when training a GAN (or a cGAN like pix2pix) compared to a simple classification or regression model. Things to look for:\n",
        "\n",
        "- Check that neither the generator nor the discriminator model has \"won\". If either the `gen_gan_loss` or the `disc_loss` gets very low, it's an indicator that this model is dominating the other, and you are not successfully training the combined model.\n",
        "- The value `log(2) = 0.69` is a good reference point for these losses, as it indicates a perplexity of 2 - the discriminator is, on average, equally uncertain about the two options.\n",
        "- For the `disc_loss`, a value below `0.69` means the discriminator is doing better than random on the combined set of real and generated images.\n",
        "- For the `gen_gan_loss`, a value below `0.69` means the generator is doing better than random at fooling the discriminator.\n",
        "- As training progresses, the `gen_l1_loss` should go down."
      ]
    },
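    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check of the `log(2)` reference point, the following plain-Python sketch evaluates the sigmoid cross-entropy for a discriminator that outputs a logit of 0, that is, a probability of 0.5. The helper `sigmoid_cross_entropy` is hypothetical and written only for this illustration. It assumes the loss is a sigmoid cross-entropy computed from logits, as described for `discriminator_loss`.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def sigmoid_cross_entropy(label, logit):\n",
        "    # Binary cross-entropy computed from a raw logit (illustrative helper).\n",
        "    p = 1.0 / (1.0 + math.exp(-logit))\n",
        "    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))\n",
        "\n",
        "# A logit of 0 corresponds to a predicted probability of 0.5.\n",
        "real_term = sigmoid_cross_entropy(1.0, 0.0)\n",
        "generated_term = sigmoid_cross_entropy(0.0, 0.0)\n",
        "\n",
        "print(round(real_term, 4))                   # 0.6931\n",
        "print(round(real_term + generated_term, 4))  # 1.3863\n",
        "```\n",
        "\n",
        "Under this assumption, each individual cross-entropy term sits at about `0.69` when the discriminator is undecided. Because `discriminator_loss` adds `real_loss` and `generated_loss`, the logged `disc_loss` for a fully undecided discriminator would be about `1.39` rather than `0.69`."
      ]
    },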
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kz80bY3aQ1VZ"
      },
      "source": [
        "## Restore the latest checkpoint and test the network"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HSSm4kfvJiqv"
      },
      "outputs": [],
      "source": [
        "!ls {checkpoint_dir}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4t4x69adQ5xb"
      },
      "outputs": [],
      "source": [
        "# Restoring the latest checkpoint in checkpoint_dir\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1RGysMU_BZhx"
      },
      "source": [
        "## Generate some images using the test set"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KUgSnmy2nqSP"
      },
      "outputs": [],
      "source": [
        "# Run the trained model on a few examples from the test set\n",
        "for inp, tar in test_dataset.take(5):\n",
        "  generate_images(generator, inp, tar)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "pix2pix.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
